{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "n_digits: 10, \t n_samples 1797, \t n_features 64\n",
      "__________________________________________________________________________________\n",
      "init\t\ttime\tinertia\thomo\tcompl\tv-meas\tARI\tAMI\tsilhouette\n"
     ]
    }
   ],
   "source": [
    "from time import time\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from sklearn import metrics\n",
    "from sklearn.cluster import KMeans,DBSCAN,AgglomerativeClustering\n",
    "from sklearn.datasets import load_digits\n",
    "from sklearn.decomposition import PCA\n",
    "from sklearn.preprocessing import scale\n",
    "from sklearn import mixture\n",
    "# Seed NumPy's global RNG so the stochastic steps below are repeatable.\n",
    "np.random.seed(42)\n",
    "\n",
    "# Load the digits dataset and standardise its features with `scale`.\n",
    "digits = load_digits()\n",
    "data = scale(digits.data)\n",
    "labels = digits.target\n",
    "\n",
    "n_samples, n_features = data.shape\n",
    "n_digits = len(np.unique(labels))\n",
    "\n",
    "# Passed to silhouette_score as `sample_size` inside bench().\n",
    "sample_size = 2000\n",
    "\n",
    "print(\"n_digits: %d, \\t n_samples %d, \\t n_features %d\" % (\n",
    "    n_digits, n_samples, n_features))\n",
    "\n",
    "# Table header for the rows printed by bench().\n",
    "print('_' * 82)\n",
    "print('init\\t\\ttime\\tinertia\\thomo\\tcompl\\tv-meas\\tARI\\tAMI\\tsilhouette')\n",
    "\n",
    "\n",
    "def bench(estimator, name, data):\n",
    "    \"\"\"Fit ``estimator`` on ``data`` and print one tab-separated score row.\n",
    "\n",
    "    The columns follow the header printed above: name, fit time, inertia,\n",
    "    homogeneity, completeness, V-measure, ARI, AMI and silhouette.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    estimator : object\n",
    "        Clusterer providing ``fit(data)`` and, once fitted, ``labels_``.\n",
    "    name : str\n",
    "        Text printed in the first column of the row.\n",
    "    data : array-like\n",
    "        Samples handed to ``fit`` and to the silhouette score.\n",
    "\n",
    "    Notes\n",
    "    -----\n",
    "    Reads the module-level ``labels`` (ground truth) and ``sample_size``.\n",
    "    \"\"\"\n",
    "    t0 = time()\n",
    "    estimator.fit(data)\n",
    "    elapsed = time() - t0  # time the fit only, not the scoring below\n",
    "\n",
    "    predicted = estimator.labels_\n",
    "\n",
    "    # The header has an inertia column, but not every estimator exposes\n",
    "    # ``inertia_``; print a placeholder so the scores stay under the\n",
    "    # right headings.\n",
    "    inertia = getattr(estimator, 'inertia_', None)\n",
    "    inertia_text = 'n/a' if inertia is None else '%i' % inertia\n",
    "\n",
    "    print('%-9s\\t%.2fs\\t%s\\t%.3f\\t%.3f\\t%.3f\\t%.3f\\t%.3f\\t%.3f'\n",
    "          % (name, elapsed, inertia_text,\n",
    "             metrics.homogeneity_score(labels, predicted),\n",
    "             metrics.completeness_score(labels, predicted),\n",
    "             metrics.v_measure_score(labels, predicted),\n",
    "             metrics.adjusted_rand_score(labels, predicted),\n",
    "             metrics.adjusted_mutual_info_score(labels, predicted,\n",
    "                                                average_method='arithmetic'),\n",
    "             metrics.silhouette_score(data, predicted,\n",
    "                                      metric='euclidean',\n",
    "                                      sample_size=sample_size)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "k-means++\t0.25s\t0.602\t0.650\t0.625\t0.465\t0.621\t0.144\n"
     ]
    }
   ],
   "source": [
    "# K-means baseline: k-means++ seeding, one cluster per digit class.\n",
    "bench(\n",
    "    KMeans(n_clusters=n_digits, init='k-means++', n_init=10),\n",
    "    name=\"k-means++\",\n",
    "    data=data,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "DBSCAN   \t0.62s\t0.003\t0.132\t0.006\t0.000\t0.004\t0.636\n"
     ]
    }
   ],
   "source": [
    "# Density-based clustering; the number of clusters is not specified.\n",
    "# NOTE(review): the recorded run scored ~0 homogeneity/ARI, which suggests\n",
    "# eps=10 puts almost every sample in one cluster -- confirm the choice.\n",
    "bench(\n",
    "    DBSCAN(min_samples=5, eps=10),\n",
    "    name=\"DBSCAN\",\n",
    "    data=data,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Hierarchical clustering\t0.19s\t0.758\t0.836\t0.796\t0.664\t0.793\t0.125\n"
     ]
    }
   ],
   "source": [
    "# Pass the unfitted estimator: bench() fits it on `data` itself, so no\n",
    "# separate fit call (and no other dataset variable) is needed here.\n",
    "bench(\n",
    "    AgglomerativeClustering(n_clusters=n_digits),\n",
    "    name=\"Hierarchical clustering\",\n",
    "    data=data,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GMM      \t1.45s\t0.598\t0.701\t0.645\t0.427\t0.642\t0.109\n"
     ]
    }
   ],
   "source": [
    "class GMM(mixture.GaussianMixture):\n",
    "    \"\"\"Gaussian mixture that exposes hard assignments as ``labels_``.\n",
    "\n",
    "    ``bench`` reads ``estimator.labels_`` after calling ``fit``. This\n",
    "    subclass stores the predicted component of every training sample in\n",
    "    that attribute so the mixture can be scored like the other clusterers.\n",
    "\n",
    "    The constructor is inherited unchanged, so keyword construction such\n",
    "    as ``GMM(n_components=10)`` keeps working and the parameters stay\n",
    "    visible to ``get_params()``.\n",
    "    \"\"\"\n",
    "\n",
    "    def fit(self, data, y=None):\n",
    "        \"\"\"Fit the mixture on ``data`` and record per-sample components.\n",
    "\n",
    "        Parameters\n",
    "        ----------\n",
    "        data : array-like\n",
    "            Training samples.\n",
    "        y : ignored\n",
    "            Accepted for compatibility with the parent ``fit`` signature.\n",
    "\n",
    "        Returns\n",
    "        -------\n",
    "        self\n",
    "            The fitted estimator, so calls can be chained.\n",
    "        \"\"\"\n",
    "        super().fit(data, y)\n",
    "        # Trailing-underscore attributes mark fitted state in scikit-learn,\n",
    "        # so labels_ is created here rather than in __init__.\n",
    "        self.labels_ = self.predict(data)\n",
    "        return self\n",
    "\n",
    "\n",
    "bench(GMM(n_components=n_digits), name=\"GMM\", data=data)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
